# runners/gif_hypersense_runner.py
from __future__ import annotations
from typing import Dict, Any
import time

import ConfigSpace as CS
import optuna
from optuna.distributions import FloatDistribution, IntDistribution, CategoricalDistribution

from objective import Objective
from loggers import ExperimentLogger
from runners.random_runner import _canonicalize

# —— Import your HyperSense components ——
from hypersense.pipeline import HyperSensePipeline
from hypersense.optimizer.optuna_optimizer import OptunaOptimizer
from hypersense.sampler.stratified_sampler import StratifiedSampler
from hypersense.importance.n_rrelieff import NRReliefFAnalyzer
from hypersense.strategy.greedy_important_first import GreedyImportantFirstStrategy


def _cs_to_hypersense_space(cs: CS.ConfigurationSpace):
    """
    Convert a ConfigSpace search space into the search space HyperSense expects
    (Optuna-style distributions).

    Returns: (space_dict, fixed_config)
    - space_dict: {name: Distribution} -- only the dimensions that actually vary
    - fixed_config: {name: value} -- constants and single-valued categorical /
      ordinal hyperparameters, pinned directly instead of being searched

    Raises:
        NotImplementedError: if a hyperparameter type has no Optuna mapping.
    """
    # Import once, before the loop. The previous per-iteration try/except
    # imported exactly the same names in both branches, so its "fallback"
    # could never rescue an ImportError.
    from ConfigSpace.hyperparameters import (
        CategoricalHyperparameter,
        Constant,
        FloatHyperparameter,
        IntegerHyperparameter,
        OrdinalHyperparameter,
        UniformFloatHyperparameter,
        UniformIntegerHyperparameter,
    )

    space: Dict[str, Any] = {}
    fixed: Dict[str, Any] = {}

    for hp in cs.values():  # CS 1.2+ API
        name = hp.name

        # Constant -> fixed
        if isinstance(hp, Constant):
            fixed[name] = hp.value
            continue

        # Categorical: a single choice is pinned, otherwise it is searched
        if isinstance(hp, CategoricalHyperparameter):
            choices = list(hp.choices)
            if len(choices) == 1:
                fixed[name] = choices[0]
            else:
                space[name] = CategoricalDistribution(choices=choices)
            continue

        # Ordinal: Optuna has no ordered-categorical distribution, so the
        # sequence is searched as plain categories (its ordering is not used).
        if isinstance(hp, OrdinalHyperparameter):
            sequence = list(hp.sequence)
            if len(sequence) == 1:
                fixed[name] = sequence[0]
            else:
                space[name] = CategoricalDistribution(choices=sequence)
            continue

        # Integer: only lower/upper/log/q are read, so any prior the
        # hyperparameter carries is not transferred to the Optuna range.
        if isinstance(hp, (UniformIntegerHyperparameter, IntegerHyperparameter)):
            low, high = int(hp.lower), int(hp.upper)
            log = bool(getattr(hp, "log", False) or getattr(hp, "log_scale", False))
            # `q` (quantization) may be absent depending on the ConfigSpace version
            step = getattr(hp, "q", None) or 1
            space[name] = IntDistribution(low=low, high=high, log=log, step=step)
            continue

        # Float
        if isinstance(hp, (UniformFloatHyperparameter, FloatHyperparameter)):
            low, high = float(hp.lower), float(hp.upper)
            log = bool(getattr(hp, "log", False) or getattr(hp, "log_scale", False))
            step = getattr(hp, "q", None)
            space[name] = FloatDistribution(low=low, high=high, log=log, step=step)
            continue

        raise NotImplementedError(f"Unsupported HP type in HyperSense space: {type(hp).__name__} ({name})")

    return space, fixed


def run_gif_hypersense(*,
                       seed: int,
                       bench: str,
                       cs: CS.ConfigurationSpace,
                       obj: Objective,
                       budget_n: int,
                       logger: ExperimentLogger,
                       method_name: str = "GIF-HyperSense",
                       # Key HyperSense parameters (can be passed through main.py for tuning)
                       warmup_n: int = 300,
                       sample_ratio: float = 0.6,
                       top_k: int | None = None,
                       step_trials: int | None = None,
                       min_trials_for_importance: int = 10,
                       full_group_ratio: float = 0.2,
                       quiet: bool = True,
                       verbose: bool = False):
    """
    Run HyperSense's GreedyImportantFirst (GIF) strategy on a ConfigSpace
    benchmark (written for NASBench301).

    - Every candidate is scored with Objective.evaluate, which returns
      (loss, sim_time). The pipeline runs with mode="min" on that loss.
    - Each successful objective call writes one unified CSV row through
      `logger`. The rows use the same fields as the Random/TPE/BOHB runners,
      with score = 1 - loss.
    - A configuration that `_canonicalize` rejects receives a penalty loss
      of 1e9. The rejection is reported through `logging`, and the config is
      NOT written to the CSV.

    Args:
        seed, bench, method_name: identifiers copied into every log row;
            `seed` is also passed to the pipeline.
        cs: search space, converted with `_cs_to_hypersense_space`.
        obj: objective wrapper; `obj.evaluate(config)` must return
            (loss, sim_time).
        budget_n: total number of trials handed to the pipeline.
        logger: receives one dict per successful evaluation via `log`.
        warmup_n: initial-sample size. It is clamped to `budget_n` so that
            the warmup alone cannot exceed the total budget.
        top_k: defaults to max(1, dim // 3), where dim is the number of
            searched dimensions.
        step_trials: defaults to max(1, dim).
        sample_ratio, min_trials_for_importance, full_group_ratio, quiet,
        verbose: forwarded unchanged to `HyperSensePipeline.run`.
    """
    import logging

    log = logging.getLogger(__name__)
    failure_loss = 1e9  # penalty loss for configs that cannot be canonicalized

    # 1) Convert search space
    hs_space, fixed = _cs_to_hypersense_space(cs)
    dim = len(hs_space)

    if top_k is None:
        # Default: select 1/3 of dimensions
        top_k = max(1, dim // 3)
    if step_trials is None:
        # Evaluation step size: use number of dimensions
        step_trials = max(1, dim)

    # Without this clamp, warmup_n > budget_n would make the warmup phase
    # alone request more evaluations than the total budget.
    initial_trials = min(warmup_n, budget_n)

    # 2) Default config: use CS default (if not available, sample once)
    try:
        default_cfg = dict(cs.get_default_configuration())
    except Exception:
        default_cfg = dict(cs.sample_configuration())
    # Merge fixed into default
    default_cfg.update(fixed)

    # 3) Construct HyperSense objective function (calls our wrapped Objective.evaluate)
    n_eval = 0
    best = float("inf")

    def nb301_objective(config: Dict[str, Any], dataset=None):
        nonlocal n_eval, best

        # Merge fixed items, ensure NB301 receives complete config
        merged = dict(config)
        merged.update(fixed)

        try:
            merged = _canonicalize(cs, merged)
        except Exception as exc:
            # Keep the best-effort penalty, but make the failure visible
            # instead of silently dropping it.
            log.warning(
                "[%s] _canonicalize rejected config (%s: %s); returning penalty loss %g",
                method_name, type(exc).__name__, exc, failure_loss,
            )
            return failure_loss

        t0 = time.perf_counter()
        loss, sim_t = obj.evaluate(merged)            # NB301 surrogate + accumulate sim_time
        elapsed = time.perf_counter() - t0

        n_eval += 1
        if loss < best:
            best = loss

        # Unified logging to CSV (same fields as Random/TPE/BOHB)
        logger.log(dict(
            seed=seed, method=method_name, bench=bench,
            n_eval=n_eval,
            sim_time=sim_t,
            elapsed_time=elapsed,
            best_score=1 - best,          # As before: score = 1 - loss
            curr_score=1 - loss,
            config=merged,
        ))
        # HyperSense pipeline will handle mode (we set mode='min')
        return loss

    # 4) Fake a "useless dataset" (interface required but not used)
    #    Just provide an empty list; if your HyperSense needs shape, sample a few points and drop labels.
    full_dataset = []

    # 5) Instantiate HyperSensePipeline (GIF strategy)
    pipeline = HyperSensePipeline(
        search_space=hs_space,                    # Optuna-style distributions
        full_dataset=full_dataset,
        objective_fn=nb301_objective,
        # NOTE(review): test_fn reuses the logging objective. Any calls the
        # pipeline makes to it are therefore counted in n_eval/best and
        # written to the CSV. Confirm that this is intended.
        test_fn=nb301_objective,
        sampler_class=StratifiedSampler,
        initial_optimizer_class=OptunaOptimizer, # Initial small sample
        whole_optimizer_class=OptunaOptimizer,   # Subsequent global optimization
        importance_analyzer_class=NRReliefFAnalyzer,
        default_config=default_cfg,
        mode="min",                              # NASBench's goal is to minimize error
        best_known_optimum=None,                 # Optional
        seed=seed,
        strategy_class=GreedyImportantFirstStrategy,
    )

    # 6) Run: small sample + GIF
    #    total_trials is the total budget; initial_trials run first, then the remaining rounds.
    pipeline.run(
        sample_ratio=sample_ratio,
        initial_trials=initial_trials,
        total_trials=budget_n,
        verbose=verbose,
        quiet=quiet,
        step_trials=step_trials,
        min_trials_for_importance=min_trials_for_importance,
        full_group_ratio=full_group_ratio,
        top_k=top_k,
    )

